{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "f407150e-d967-4dfc-943c-65a9506a71dd",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 165,
   "id": "6bacb236-523c-44ab-8a32-d663f1d5d747",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import pickle\n",
    "from typing import List, Dict\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 125,
   "id": "697ff8d3-ec8e-4ebe-aeda-f9eabda3f915",
   "metadata": {},
   "outputs": [],
   "source": [
    "# These pickles are the outputs of the medbert repo: https://github.com/ZhiGroup/Med-BERT/blob/master/Pretraining%20Code/Readme.md\n",
    "PRETRAINED_TRAIN_PICKLE_PATH = '/sise/home/benshoho/projects/Med-BERT/Pretraining Code/Data Pre-processing Code/temp-try.bencs.train'\n",
    "MEDBERT_CODES_DICT_PATH = '/sise/home/benshoho/projects/Med-BERT/Pretraining Code/Data Pre-processing Code/temp-try.types'\n",
    "PRETRAINED_VALIDATION_PICKLE_PATH = '/sise/home/benshoho/projects/Med-BERT/Pretraining Code/Data Pre-processing Code/temp-try.bencs.valid'\n",
    "PRETRAINED_TEST_PICKLE_PATH = '/sise/home/benshoho/projects/Med-BERT/Pretraining Code/Data Pre-processing Code/temp-try.bencs.test'\n",
    "\n",
    "\n",
    "MEDBERT_OUTPUT_PICKLES_DIR = '/sise/home/benshoho/projects/Med-BERT/Fine-Tunning-Tutorials/data/mimic-iv'\n",
    "\n",
    "TARGET_DISEASE_IDS = {'157'} # CCS category code.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "81d34335-f300-42a3-8850-3aa83a90c95d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "with open(MEDBERT_CODES_DICT_PATH, 'rb') as file:\n",
    "    loaded_data = pickle.load(file)\n",
    "loaded_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 127,
   "id": "de36c57e-abbe-41cb-a96a-ae95ec19cc5d",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(PRETRAINED_TRAIN_PICKLE_PATH, 'rb') as f:\n",
    "    medbert_train_data = pickle.load(f)\n",
    "with open(PRETRAINED_VALIDATION_PICKLE_PATH, 'rb') as f:\n",
    "    medbert_validation_data = pickle.load(f)\n",
    "with open(PRETRAINED_TEST_PICKLE_PATH, 'rb') as f:\n",
    "    medbert_test_data = pickle.load(f)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 128,
   "id": "f3658eb5-2fe8-4c02-bb56-af03a117313f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(59118, 8445, 16890)"
      ]
     },
     "execution_count": 128,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(medbert_train_data), len(medbert_validation_data), len(medbert_test_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "568b598a-6aff-4ccd-9c72-2d4bdc61e056",
   "metadata": {},
   "source": [
    "### Convert pickle to df\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd28ae15-a397-43d0-9066-bd718e66ba13",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_df = pd.DataFrame(medbert_train_data, columns= ['person_id', 'los', 'time_not_used', 'code', 'visits'])\n",
    "validation_df = pd.DataFrame(medbert_validation_data, columns= ['person_id', 'los', 'time_not_used', 'code', 'visits'])\n",
    "test_df = pd.DataFrame(medbert_test_data, columns= ['person_id', 'los', 'time_not_used', 'code', 'visits'])\n",
    "for x in (train_df, validation_df, test_df):\n",
    "    x.drop(columns=['los', 'time_not_used'], inplace=True)\n",
    "train_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d6f96c9-30e7-4db9-bd7e-f17451474c7a",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(MEDBERT_CODES_DICT_PATH, 'rb') as f:\n",
    "          code_to_id_dict = pickle.load(f)\n",
    "print(code_to_id_dict)\n",
    "\n",
    "def convert_codes_to_ids(codes: List[str], code_to_id_dict: Dict[str, int]):\n",
    "    converted_codes = []\n",
    "    for code in codes: \n",
    "        converted_codes.append(code_to_id_dict[str(code)])\n",
    "    return converted_codes\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 132,
   "id": "e8142cff-2794-4309-a1cf-96ff1fd73c03",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[61]"
      ]
     },
     "execution_count": 132,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "TARGET_DISEASE_IDS = convert_codes_to_ids(TARGET_DISEASE_IDS, code_to_id_dict)\n",
    "TARGET_DISEASE_IDS"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3b66e1d3-9eb7-499b-b495-5b7ef1209002",
   "metadata": {},
   "source": [
    "\n",
    "## Convert to medbert format"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a626213-1465-4e21-9118-4f0e5c8959c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_df\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 134,
   "id": "6e5b5c4f-cf87-4094-9155-4be277267ae2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def add_sep_between_visits(row):\n",
    "    codes, visits = row.code, row.visits\n",
    "    new_codes = []\n",
    "    new_visits = []\n",
    "    \n",
    "    for i in range(len(codes)):\n",
    "        new_codes.append(codes[i])\n",
    "        new_visits.append(visits[i])\n",
    "        if i < len(codes) - 1 and visits[i] != visits[i + 1]:\n",
    "            new_codes.append('SEP')\n",
    "            new_visits.append('SEP')\n",
    "    new_codes.append('SEP')\n",
    "    new_visits.append('SEP')\n",
    "    assert len(new_codes) == len(new_visits)\n",
    "    return new_codes, new_visits\n",
    "\n",
    "for x in (train_df, validation_df, test_df):\n",
    "    x[['code', 'visits']] = x.apply(add_sep_between_visits, axis=1, result_type='expand')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 137,
   "id": "0764fa47-0c05-415c-9452-aa43b23493d4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "18654\n"
     ]
    }
   ],
   "source": [
    "def count_lists_with_fewer_seps(inner_list):\n",
    "    sep_count = inner_list.count('SEP')\n",
    "    if sep_count < 2:\n",
    "        return True\n",
    "    return False\n",
    "\n",
    "\n",
    "def count_lists_with_target_before_sep(inner_list):\n",
    "    found_sep = False\n",
    "\n",
    "    for item in inner_list:\n",
    "        if item == 'SEP':\n",
    "            found_sep = False\n",
    "            return False\n",
    "        if item in TARGET_DISEASE_IDS and not found_sep:\n",
    "            return True\n",
    "    return False\n",
    "\n",
    "all_codes = list(train_df['code'])\n",
    "total_count = 0\n",
    "for codes_list in all_codes:\n",
    "    if count_lists_with_target_before_sep(codes_list):\n",
    "        total_count += 1\n",
    "    elif count_lists_with_fewer_seps(codes_list):\n",
    "        total_count += 1\n",
    "    \n",
    "print(len(all_codes) - total_count)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eaa9f7ec-80db-4f56-a79b-f05a2da56b63",
   "metadata": {},
   "outputs": [],
   "source": [
    "def target_disease_in_first_visit(row):\n",
    "    codes = row.code\n",
    "    visits_num = row.visits\n",
    "    indexes_to_remove = []\n",
    "    for code, visit in zip(codes, visits_num):\n",
    "        if code in TARGET_DISEASE_IDS and visit == 1:\n",
    "            return True\n",
    "    return False\n",
    "\n",
    "for x in (train_df, validation_df, test_df):\n",
    "    mask = x.apply(target_disease_in_first_visit, axis=1)\n",
    "    x.drop(index=x[mask].index, inplace=True)\n",
    "    x.drop(x[x['visits'].apply(lambda x: x.count('SEP') < 2)].index, inplace=True)  # Remove rows with only one visit.\n",
    "\n",
    "train_df.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8943795-30c0-450a-84d8-91c3cbe88d1f",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_df.shape, validation_df.shape, test_df.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "61563d8d-7b6f-4100-8bf4-44e5c2c83b0b",
   "metadata": {},
   "source": [
    "### To medbert pickle format for fine-tuning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 141,
   "id": "2d1376b0-cf92-44fb-b9f4-02db2b7d4c4f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def find_start_index_of_visit(visits, visit_num_to_predict):\n",
    "    return visits.index(visit_num_to_predict)\n",
    "\n",
    "def get_history(codes, visits, visit_num_to_predict):\n",
    "    index = find_start_index_of_visit(visits, visit_num_to_predict)\n",
    "    return codes[:index], visits[:index]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 142,
   "id": "83d470ca-fb38-4097-8fa8-ca6dd008e421",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "([1, 2, 4, 'SEP', 5, 'SEP', 6, 10], [1, 1, 1, 1, 1, 1, 1, 'SEP'])"
      ]
     },
     "execution_count": 142,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "get_history([1, 2, 4, 'SEP', 5, 'SEP', 6, 10, 'SEP'], [1, 1, 1, 1, 1, 1, 1, 'SEP', 2, 2, 2, 2, 2, 2, 'SEP'], 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 143,
   "id": "cee913f8-a220-4f21-80cc-dd643aa7b210",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import List\n",
    "import random\n",
    "\n",
    "def has_target_disease(target_disease_ids: List[str], codes: List[str]):\n",
    "    # return True if at least one from target_disease_ids can be found in codes and its index in the codes.\n",
    "    for index, code in enumerate(codes):\n",
    "        if code in target_disease_ids:\n",
    "            if index == 0:\n",
    "                print(codes)\n",
    "            return True\n",
    "    return False\n",
    "\n",
    "def random_negative_visit(visits: List[str]):\n",
    "    # get a random visit number. \n",
    "    temp_visits = set(visits)\n",
    "    # print(temp_visits)\n",
    "    temp_visits.remove('SEP')\n",
    "    temp_visits.remove(1) # because we need history of at least one visit. \n",
    "    random_visit_num = random.choice(list(temp_visits))\n",
    "    return random_visit_num # that's the index to predict. we need at least two visits (one for history and second to predict)\n",
    "\n",
    "def get_positive_visit(codes, visits): \n",
    "    if not codes:\n",
    "        print('Empty list of codes!!!!!!!!!!!!!!!!!!!!!')\n",
    "        return []\n",
    "    for index, code in enumerate(codes):\n",
    "        if code in TARGET_DISEASE_IDS:\n",
    "            return visits[index] # that's the index to predict. we need at least two visits (one for history and second to predict)\n",
    "    print('Index was not found!')\n",
    "    return []\n",
    "\n",
    "def preprocess_patient_records(codes, visits, was_target_found: bool):\n",
    "    if was_target_found:\n",
    "        visit_num = get_positive_visit(codes, visits)\n",
    "    else:\n",
    "        visit_num = random_negative_visit(visits)\n",
    "    return get_history(codes, visits, visit_num)\n",
    "    \n",
    "def preprocess_patient_data(row):\n",
    "    person_id = row['person_id']\n",
    "    codes = row['code']\n",
    "    visits = row['visits']\n",
    "    assert len(codes) == len(visits)\n",
    "    classification_binary_label = has_target_disease(TARGET_DISEASE_IDS, codes)\n",
    "    #try:\n",
    "    codes, visits = preprocess_patient_records(codes, visits, classification_binary_label)\n",
    "    # except: \n",
    "    #     print(f'person_id={person_id}')\n",
    "    assert len(codes) == len(visits)\n",
    "\n",
    "    return codes, visits, 1 if classification_binary_label else 0\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a6607c8f-4fb8-4308-a07d-d518efa52ff3",
   "metadata": {},
   "outputs": [],
   "source": [
    "for x in (train_df, validation_df, test_df):\n",
    "    x[['code', 'visits', 'label']] = x.apply(preprocess_patient_data, axis=1, result_type='expand')\n",
    "train_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2be734cf-3e5a-4c7c-af82-02dc2ba58a55",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_df['label'].value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab9a17d6-afad-45a3-bcb4-f2ea304ef4f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_df['label'].value_counts()[0] / train_df['label'].value_counts()[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e314a375-f328-4d1a-8592-3b9cc1be3d1b",
   "metadata": {},
   "outputs": [],
   "source": [
    "validation_df['label'].value_counts()[0] / validation_df['label'].value_counts()[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb8deb57-6156-48d5-8923-b656cde2f943",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_df['label'].value_counts()[0] / test_df['label'].value_counts()[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0af11842-35f5-42e3-ba8d-4fb82025c117",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_df[train_df['label'] == 1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58cc5184-3622-4b88-adac-c02dff9ddda6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import Counter\n",
    "Counter(train_df['label'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "420a8430-4bdf-4f87-914e-03052af38840",
   "metadata": {},
   "outputs": [],
   "source": [
    "def filter_sep(row):\n",
    "    codes = row.code\n",
    "    visits_num = row.visits\n",
    "    indexes_to_remove = []\n",
    "    assert len(codes) == len(visits_num)\n",
    "    for index, code in enumerate(codes):\n",
    "        if code == 'SEP':\n",
    "            indexes_to_remove.append(index)\n",
    "    assert len(codes) == len(visits_num)\n",
    "    codes = [code for i, code in enumerate(codes) if i not in indexes_to_remove]\n",
    "    visits_num = [num for i, num in enumerate(visits_num) if i not in indexes_to_remove]\n",
    "    assert len(codes) == len(visits_num)\n",
    "    return codes, visits_num\n",
    "\n",
    "for x in (train_df, validation_df, test_df):\n",
    "    x[['code', 'visits']] = x.apply(filter_sep, axis=1, result_type='expand')\n",
    "\n",
    "train_df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ba4b18f-d6cc-40b3-92b3-c65fd072b190",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_df[train_df['person_id'] == 14339711]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cab0ea7a-b9af-4e63-a8ff-ae3c45030224",
   "metadata": {},
   "source": [
    "### To pickles"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 127,
   "id": "c286362e-f43d-4ab0-b4e8-d2afcedd4e0e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def write_df_to_pickle(df: pd.DataFrame, pickle_output_dir: str, df_type: str, disease_name: str):\n",
    "    # df with columns: person_id, code\n",
    "    # Create a list to store patient records\n",
    "    patient_records = []\n",
    "\n",
    "    # Iterate over each row in the DataFrame\n",
    "    for index, row in df.iterrows():\n",
    "        # Extract the necessary information from the row\n",
    "        pt_id = row['person_id']\n",
    "        label = row['label']\n",
    "        seq_list = row['code']\n",
    "        segment_list = row['visits']\n",
    "        # print(seq_list)\n",
    "        # print(segment_list)\n",
    "        assert len(seq_list) == len(segment_list)\n",
    "        \n",
    "        # Create a patient record as a sublist\n",
    "        patient_record = [pt_id, label, seq_list, segment_list]\n",
    "        # Append the patient record to the list of patient records\n",
    "        patient_records.append(patient_record)\n",
    "\n",
    "    # Write the list of patient records to a pickle file\n",
    "    output_pickle_path = f'{pickle_output_dir}/{disease_name}_{df_type}.pickle'\n",
    "    with open(output_pickle_path, 'wb') as file:\n",
    "        pickle.dump(patient_records, file)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 128,
   "id": "590eaca9-8caa-4046-a8f6-8ae030314d79",
   "metadata": {},
   "outputs": [],
   "source": [
    "for current_df, current_df_type in zip([train_df, validation_df, test_df], ['train', 'validation', 'test']):\n",
    "    write_df_to_pickle(current_df, MEDBERT_OUTPUT_PICKLES_DIR, current_df_type, disease_name='chronic_kidney_disease')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab017452-bb93-4af6-8148-3d44b78e98e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 130,
   "id": "ab7202dc-efc4-4988-b10b-7da481827005",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'/sise/home/benshoho/projects/Med-BERT/Fine-Tunning-Tutorials/data/mimic-iv'"
      ]
     },
     "execution_count": 130,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "MEDBERT_OUTPUT_PICKLES_DIR"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9b9f629b-94e8-4556-8361-a8295012f5b3",
   "metadata": {},
   "source": [
    "## Represent as icd10 description after aggegation instead the aggegated code"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 131,
   "id": "36c3df23-22d3-4cce-8111-0a23aac1620a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import requests\n",
    "\n",
    "def get_icd10_description(icd_code):\n",
    "    print(f'sending http request for icd_code={icd_code}')\n",
    "    base_url = \"http://icd10api.com/\"\n",
    "    params = {\n",
    "        \"code\": icd_code,\n",
    "        \"desc\": \"short\",\n",
    "        \"r\": \"json\"\n",
    "    }\n",
    "\n",
    "    try:\n",
    "        response = requests.get(base_url, params=params)\n",
    "        response.raise_for_status()  # Raise an exception for unsuccessful responses\n",
    "        data = response.json()\n",
    "        description = data.get(\"Description\")\n",
    "        return description\n",
    "    except Exception as e:\n",
    "        print(f\"Error occurred during API request: {e}\")\n",
    "        return None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 132,
   "id": "162805ef-ba2f-446c-bbc0-0a537bc0653f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import requests\n",
    "\n",
    "id_to_aggregated_code_dict = {v:k for k,v in code_to_id_dict.items()}\n",
    "\n",
    "icd_aggegator_df = pd.read_csv('/sise/home/benshoho/projects/feature extraction/ccs_dx_icd10cm_2018_1.csv') # ccs aggegations.\n",
    "icd_aggegator_df.columns = icd_aggegator_df.columns.str.replace(\"'\", \"\")\n",
    "icd_aggegator_df = icd_aggegator_df.applymap(lambda x: x.replace(\"'\", \"\"))\n",
    "\n",
    "icd10_ccs_mapping = icd_aggegator_df.set_index('CCS CATEGORY')['CCS CATEGORY DESCRIPTION'].to_dict() # for example:'1': 'Tuberculosis',\n",
    "\n",
    "def from_medbert_code_to_description(code):\n",
    "    aggregated_code = id_to_aggregated_code_dict[code]\n",
    "    if aggregated_code in icd10_ccs_mapping:\n",
    "        text_description = icd10_ccs_mapping[aggregated_code]\n",
    "    else:\n",
    "        # in case of missing description we get the description from icd10api. \n",
    "        text_description = get_icd10_description(aggregated_code)\n",
    "    return text_description\n",
    "\n",
    "def from_medbert_codes_to_description(codes):\n",
    "    return [from_medbert_code_to_description(code) for code in codes]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1a124d8-7483-4071-8c9d-65190b3704a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "from_medbert_codes_to_description([1, 44, 33, 89, 44])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "764efaf4-13c7-4963-b4e5-d2c5d0f2caee",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_desc = set(icd10_ccs_mapping.values())\n",
    "for data in (train_with_description_df, validation_with_description_df, test_with_description_df):\n",
    "    data_codes =  data['code']\n",
    "    for c in data_codes:\n",
    "        all_desc = all_desc.union(c)\n",
    "\n",
    "import pickle\n",
    "with open('/sise/home/benshoho/projects/Med-BERT/Pretraining Code/Data Pre-processing Code/mimic_iv_descriptions_set.types', \"wb\") as pickle_file:\n",
    "    pickle.dump(all_desc, pickle_file)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8aa8e938-3b23-4196-aa4f-d485664f9fb3",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_with_description_df, validation_with_description_df, test_with_description_df = train_df.copy(), validation_df.copy(), test_df.copy()\n",
    "for x in (train_with_description_df, validation_with_description_df, test_with_description_df):\n",
    "    x['code'] = x['code'].apply(from_medbert_codes_to_description)\n",
    "train_with_description_df"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ad0d8145-4fe6-49fe-92ea-a3b03932e6c3",
   "metadata": {},
   "source": [
    "### save new pickles with text description instead of numbers. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 136,
   "id": "1a0f9123-190a-441e-8fa9-ba556a871a2d",
   "metadata": {},
   "outputs": [],
   "source": [
    "for current_df, current_df_type in zip([train_with_description_df, validation_with_description_df, test_with_description_df], ['train', 'validation', 'test']):\n",
    "    write_df_to_pickle(current_df, MEDBERT_OUTPUT_PICKLES_DIR, current_df_type, disease_name='chronic_kidney_disease_descriptions')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
